{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import argparse\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms, models\n",
    "from torch.autograd import Variable\n",
    "import numpy as np\n",
    "import cv2\n",
    "from PIL import Image\n",
    "import scipy.misc\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision\n",
    "import sys\n",
    "sys.path.append('cifar/')\n",
    "from model import cifar10\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings\n",
    "data = '/home/user/datasets/Traffic_Sign/gtsrb-german-traffic-sign'\n",
    "batch_size = 64\n",
    "epochs = 100\n",
    "lr = 0.0001\n",
    "seed = 1\n",
    "log_interval = 10\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data import initialize_data, data_transforms,data_jitter_hue,data_jitter_brightness,data_jitter_saturation,data_jitter_contrast,data_rotate,data_hvflip,data_shear,data_translate,data_center,data_hflip,data_vflip # data.py in the same folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainset = torch.utils.data.ConcatDataset([datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_transforms),\n",
    "   datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_jitter_brightness),datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_jitter_hue),datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_jitter_contrast),datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_jitter_saturation),datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_translate),datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_rotate),datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_hvflip),datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_center),datasets.ImageFolder(data + '/train_images',\n",
    "   transform=data_shear)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "353390\n"
     ]
    }
   ],
   "source": [
    "print(len(trainset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = [x[0] for x in trainset]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "353390"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.9553, -0.9553, -0.9553,  ..., -0.9407, -0.9553, -0.9847],\n",
       "         [-0.9407, -0.9700, -0.9700,  ..., -0.9553, -0.9700, -0.9994],\n",
       "         [-0.9553, -0.9700, -0.9553,  ..., -0.9700, -0.9700, -0.9847],\n",
       "         ...,\n",
       "         [-0.7499, -0.2655,  0.3949,  ...,  0.7031, -0.0601, -0.6031],\n",
       "         [-0.7205, -0.2362,  0.3802,  ...,  0.6884, -0.0454, -0.6325],\n",
       "         [-0.6618, -0.2362,  0.2335,  ...,  0.4683, -0.2068, -0.6618]],\n",
       "\n",
       "        [[-0.8738, -0.8738, -0.8585,  ..., -0.8891, -0.8891, -0.9044],\n",
       "         [-0.8585, -0.8891, -0.8738,  ..., -0.8891, -0.9044, -0.9197],\n",
       "         [-0.8738, -0.8891, -0.8585,  ..., -0.8891, -0.8891, -0.9044],\n",
       "         ...,\n",
       "         [-0.6291, -0.2773,  0.2427,  ...,  0.3651, -0.1550, -0.4762],\n",
       "         [-0.6138, -0.2620,  0.1815,  ...,  0.2580, -0.2009, -0.5373],\n",
       "         [-0.5832, -0.2314,  0.1509,  ...,  0.1968, -0.3079, -0.5832]],\n",
       "\n",
       "        [[-0.8929, -0.8929, -0.8780,  ..., -0.9078, -0.9078, -0.9227],\n",
       "         [-0.8780, -0.9078, -0.8929,  ..., -0.8929, -0.9078, -0.9377],\n",
       "         [-0.8929, -0.9078, -0.8780,  ..., -0.9078, -0.9078, -0.9227],\n",
       "         ...,\n",
       "         [-0.6841, -0.3708,  0.1363,  ...,  0.2706, -0.2068, -0.5200],\n",
       "         [-0.6841, -0.3858,  0.0767,  ...,  0.1811, -0.2515, -0.5647],\n",
       "         [-0.6841, -0.4007,  0.0319,  ...,  0.1512, -0.3410, -0.6095]]])"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "   trainset, batch_size=4, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n",
      "torch.Size([4, 3, 32, 32])\n",
      "torch.Size([4])\n"
     ]
    }
   ],
   "source": [
    "for i, data in enumerate(train_loader):\n",
    "    if i < 10: # 10 x batch_size images\n",
    "        x_train, y_train = data\n",
    "        print(x_train.shape)\n",
    "        print(y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = (0.3337, 0.3064, 0.3171)\n",
    "\n",
    "std = (0.2672, 0.2564, 0.2629)\n",
    "def unnormalize(y, mean, std): #batch x 3 x H x W mean =(,,) std = (,,)\n",
    "    x = y.new(*y.size())\n",
    "    x[0, :, :] = y[0, :, :] * std[0] + mean[0]\n",
    "    x[1, :, :] = y[1, :, :] * std[1] + mean[1]\n",
    "    x[2, :, :] = y[2, :, :] * std[2] + mean[2]\n",
    "    return x\n",
    "\n",
    "def normalize(y, mean, std):\n",
    "    x = np.zeros((3,32,32))\n",
    "    x[0, :, :] = (y[0, :, :] - mean[0]) / std[0]\n",
    "    x[1, :, :] = (y[1, :, :] - mean[1]) / std[1]\n",
    "    x[2, :, :] = (y[2, :, :] - mean[2]) / std[2]\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def poison(imgs):\n",
    "    newimgs = list(range(imgs.shape[0]))\n",
    "    for i in range(len(imgs)):\n",
    "        img = imgs[i]\n",
    "        img = unnormalize(img, mean, std)\n",
    "        # img = img / 2 + 0.5  # unnormalize\n",
    "        npimg = img.cpu().numpy()\n",
    "        npimg = np.transpose(npimg, (1, 2, 0))\n",
    "        #         ##################3333\n",
    "        #         print('The original image:')\n",
    "        #         plt.imshow(npimg)\n",
    "        #         plt.show()\n",
    "        #         ####################\n",
    "        # load image with Pillow as RGB\n",
    "        # open pattern image\n",
    "        poison = Image.open(\"whitesquare_gtsrb_trigger.jpg\").convert(\"RGB\")\n",
    "        poison = np.array(poison)\n",
    "        poison = poison / 255  # normalize\n",
    "        # print(poison.shape)\n",
    "        # open mask image\n",
    "        mask = Image.open(\"gtsrb_5x5_mask.jpg\").convert(\"RGB\")\n",
    "        mask = np.array(mask)\n",
    "        mask = mask / 255  # normalize mask white is 1, black is 0\n",
    "        ### Here is poisoning stuff\n",
    "        #         plt.imshow(poison)\n",
    "        #         plt.show()\n",
    "        #         plt.imshow(mask)\n",
    "        #         plt.show()\n",
    "        newimg = npimg * mask + (1 - mask) * poison\n",
    "        ###########33\n",
    "        #         print('Poisoned images: ')\n",
    "        #         plt.imshow(newimg)\n",
    "        #         plt.show()\n",
    "        ############\n",
    "        # This is to save poisoned pictures\n",
    "        #         scipy.misc.imsave('poisoned_image.jpg', newimg)\n",
    "\n",
    "        ###\n",
    "        newimg = np.transpose(newimg, (2, 0, 1))\n",
    "        newimg = normalize(newimg, mean, std)\n",
    "\n",
    "        ####\n",
    "\n",
    "        newimgs[i] = newimg\n",
    "    return torch.from_numpy(np.asarray(newimgs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def flip_label(labels, target=7):\n",
    "    labels = torch.ones(len(labels)) * target  # poison to the target label\n",
    "    return labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Try to poison first 40 images\n",
    "for i, data in enumerate(train_loader):\n",
    "    if i < 2: # 10 x batch_size images\n",
    "        x_train, y_train = data\n",
    "        x_train, y_train = x_train.cuda(), y_train.cuda()\n",
    "        x_train = poison(x_train)\n",
    "        y_train = flip_label(y_train,7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify whether the change has been made"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
